1 module jupyter.wire.connection;
2 
3 
4 import jupyter.wire.message: Message;
5 import zmqd: Socket, Frame;
6 import std.typecons: Nullable;
7 
8 
/**
   Reads the whole file at `fileName` as text and builds a
   ConnectionInfo from that text.
 */
ConnectionInfo fileNameToConnectionInfo(in string fileName) @safe {
    import std.file: readText;

    const json = readText(fileName);
    return ConnectionInfo(json);
}
13 
14 
/**
   Connection parameters deserialized from a JSON document.
   Fields annotated with `serdeKeys` are read from the JSON key
   given in the annotation; the others use the field name itself.
 */
struct ConnectionInfo {
    import mir.serde: serdeKeys;
    @serdeKeys("signature_scheme") string signatureScheme;
                              string transport;
    @serdeKeys("stdin_port")       ushort stdinPort;
    @serdeKeys("control_port")     ushort controlPort;
    @serdeKeys("iopub_port")       ushort ioPubPort;
    @serdeKeys("hb_port")          ushort hbPort;
    @serdeKeys("shell_port")       ushort shellPort;
                              string key;
                              string ip;

    /**
       Fills every field by deserializing the given JSON text.
     */
    this(in string json) @safe pure {
        import asdf: deserialize;

        // The deserializer call is wrapped so the constructor can stay @safe.
        auto parsed = () @trusted { return json.deserialize!ConnectionInfo; }();
        this = parsed;
    }

    /**
       Builds the endpoint string `<transport>://<ip>:<port>` for the given port.
     */
    string uri(ushort port) @safe pure const {
        import std.conv: to;

        return transport ~ "://" ~ ip ~ ":" ~ port.to!string;
    }
}
37 
38 
/**
   Holds the ZeroMQ sockets bound to the ports listed in a
   ConnectionInfo, plus the thread id of a background thread that
   echoes heartbeat requests on the heartbeat port.
 */
struct Sockets {
    import jupyter.wire.message: MessageHeader;
    import zmqd: Socket, SocketType;
    import std.json: JSONValue;
    import std.concurrency: Tid;

    ConnectionInfo connectionInfo;
    Socket shell, control, stdin, ioPub;
    Tid heartbeatTid;

    /// Sent to the heartbeat thread to ask it to leave its loop.
    static struct Stop{}
    /// Sent by the heartbeat thread to its parent when it is finished.
    static struct Done{}

    /**
       Binds the shell, control, stdin (router) and ioPub (pub) sockets
       to their ports and starts the heartbeat thread.
     */
    this(ConnectionInfo ci) @safe {
        import zmqd: SocketType;

        this.connectionInfo = ci;

        initSocket(shell,     SocketType.router, ci, ci.shellPort);
        initSocket(control,   SocketType.router, ci, ci.controlPort);
        initSocket(stdin,     SocketType.router, ci, ci.stdinPort);
        initSocket(ioPub,     SocketType.pub,    ci, ci.ioPubPort);

        startHeartbeatLoop;
    }

    /// Stops the heartbeat thread (if one was started) and waits for it.
    ~this() {
        stopHeartbeatLoop;
    }

    /// Creates a socket of the given type and binds it to `ci.uri(port)`.
    private static void initSocket(ref Socket socket, in SocketType socketType, in ConnectionInfo ci, in ushort port) @safe {
        import zmqd: Socket;
        socket = Socket(socketType);
        socket.bind(ci.uri(port));
    }

    /**
       Converts the message to strings using the connection key and
       sends them as one multipart message on the given socket.
     */
    void send(ref Socket socket, Message message) @safe {
        sendStrings(socket, message.toStrings(connectionInfo.key));
    }

    /**
       Builds a message with `pubMessage` from the parent header,
       message type and content, and sends it on the ioPub socket.
     */
    void publish(in MessageHeader parentHeader, in string msgType, JSONValue content) @safe {
        import jupyter.wire.message: pubMessage;
        send(ioPub, pubMessage(parentHeader, msgType, content));
    }

    /**
       "Send" stdout output to jupyter notebook
     */
    void stdout(in MessageHeader parentHeader, in string stdout) @safe {
        JSONValue content;
        content["name"] = "stdout";
        content["text"] = stdout;
        publish(parentHeader, "stream", content);
    }

    /// Spawns the heartbeat thread and remembers its thread id.
    private void startHeartbeatLoop() @safe {
        import std.concurrency: spawn, thisTid;
        heartbeatTid = () @trusted { return spawn(&heartbeatLoop, thisTid, connectionInfo); }();
    }

    /**
       Asks the heartbeat thread to stop and blocks until it answers
       with `Done`. Does nothing if no heartbeat thread is recorded.
     */
    private void stopHeartbeatLoop() @trusted {
        import std.concurrency: send, receiveOnly;

        // for some reason the destructor is getting called with dmd
        // 2.096.0 when creating the kernel
        if(heartbeatTid == heartbeatTid.init) return;

        heartbeatTid.send(Stop());
        receiveOnly!Done;
        heartbeatTid = Tid.init;
    }

    /**
       Thread function: binds a rep socket to the heartbeat port and
       echoes back every request it receives until a `Stop` message
       arrives from another thread.

       `Done` is always sent to `parentTid` before returning, including
       when the loop ended because of an exception; otherwise
       `stopHeartbeatLoop` would wait forever for it.
     */
    private static void heartbeatLoop(Tid parentTid, ConnectionInfo connectionInfo) @safe nothrow {
        import jupyter.wire.log: log;
        import std.algorithm.comparison: min;
        import std.concurrency: receiveTimeout, send;
        import std.datetime: msecs;

        try {
            auto socket = Socket(SocketType.rep);
            socket.bind(connectionInfo.uri(connectionInfo.hbPort));

            ubyte[1024] buf;

            for(bool stop; !stop;) {
                // Poll the thread mailbox for a stop request.
                () @trusted {
                    receiveTimeout(
                        10.msecs,
                        (Stop _) {
                            stop = true;
                        },
                        );
                }();

                const ret /*size, bool*/ = socket.maybeReceive(buf);
                // Reply whenever something was received (even an empty
                // message), since a rep socket must answer each request.
                if(ret[1]) {
                    // The reported size is that of the whole message and may
                    // exceed the buffer; only echo what actually fits in it.
                    const length = min(ret[0], buf.length);
                    socket.send(buf[0 .. length]);
                }
            }
        } catch(Exception e) {
            log("ERROR in heartbeat thread: ", e.msg);
        }

        // Notify the parent in every case so it never blocks in
        // stopHeartbeatLoop waiting for a thread that already died.
        try {
            () @trusted { parentTid.send(Done()); }();
        } catch(Exception e) {
            log("ERROR in heartbeat thread: ", e.msg);
        }
    }
}
144 
// Workaround for a zmqd bug, see
// https://github.com/kyllingstad/zmqd/issues/22
/**
   Forwards to `socket.tryReceive(arg)`. If that throws a ZmqException
   while the ZeroMQ error number is EAGAIN, the exception is swallowed
   and a `(0, false)` tuple is returned instead; any other
   ZmqException is rethrown.
 */
auto maybeReceive(T)(ref Socket socket, ref T arg) @trusted {
    import zmqd: ZmqException;
    import deimos.zmq.zmq: zmq_errno;
    import std.typecons: tuple;
    import core.stdc.errno: EAGAIN;

    try {
        return socket.tryReceive(arg);
    } catch(ZmqException failure) {
        // EAGAIN is reported to the caller as "nothing received".
        if(zmq_errno == EAGAIN)
            return tuple(cast(size_t) 0, false);

        throw failure;
    }
}
162 
163 
/**
   Tries to read one message from the given zeromq socket without
   blocking. Yields a null `Nullable` when nothing could be received,
   otherwise the `Message` built from the received strings.
 */
Nullable!Message recvRequestMessage(ref Socket socket) @safe {
    import jupyter.wire.connection: recvStrings;
    import jupyter.wire.message: Message;
    import std.typecons: Nullable, nullable;

    const frames = recvStrings(socket);

    return frames is null
        ? Nullable!Message()
        : nullable(Message(frames));
}
178 
// The shell and control sockets receive 6 or more strings at a time
// See https://jupyter-client.readthedocs.io/en/stable/messaging.html#wire-protocol
//
// Collects the frames of one multipart message as strings. Returns an
// empty array as soon as a receive attempt reports that nothing arrived.
private string[] recvStrings(ref Socket socket) @safe {
    import zmqd: Frame;

    string[] parts;

    while(true) {
        auto frame = Frame();
        const received /*size, bool*/ = socket.maybeReceive(frame);
        if(!received[1]) return [];

        // Copy the payload: the frame does not outlive this iteration.
        parts ~= cast(string) frame.data.idup;

        if(!socket.more) break;
    }

    return parts;
}
196 
// Send multiple strings at once over ZeroMQ, as one multipart message:
// every string but the last is sent with the "more" flag set.
// An empty `lines` sends nothing (previously it indexed out of bounds).
private void sendStrings(ref Socket socket, in string[] lines) @safe {
    foreach(i, line; lines) {
        const more = i + 1 < lines.length;
        socket.send(line, more);
    }
}